from typing import List, Dict
import requests
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage
import os
from config import Config

class LiteratureSearchTool:
    """Search for academic papers through the Semantic Scholar HTTP API.

    The tool is best-effort: any failure (network error, non-200 status,
    unexpected payload) is reported on stdout and yields an empty list
    instead of raising.
    """

    # Fields requested for every hit.  Per the Semantic Scholar Graph API
    # documentation, the search endpoint only returns ``paperId`` and
    # ``title`` unless the wanted fields are listed explicitly.
    PAPER_FIELDS = ("title", "authors", "year", "abstract", "url")

    def __init__(self, timeout: float = 10.0):
        """Create the tool.

        Args:
            timeout: Seconds to wait for the HTTP request before giving up.
        """
        self.api_base = Config.SEMANTIC_SCHOLAR_API_BASE
        self.timeout = timeout

    @staticmethod
    def _to_record(paper: Dict) -> Dict:
        """Reduce one raw API hit to the keys this tool exposes.

        Missing keys become ``None`` (``[]`` for ``authors``) so that one
        sparse hit cannot invalidate the whole result list.
        """
        return {
            "title": paper.get("title"),
            "authors": paper.get("authors") or [],
            "year": paper.get("year"),
            "abstract": paper.get("abstract"),
            "url": paper.get("url"),
        }

    def search_papers(self, query: str, limit: int = 10) -> List[Dict]:
        """Search for papers related to *query*.

        Args:
            query: Free-text search string.
            limit: Maximum number of papers to request.

        Returns:
            A list of dicts with the keys ``title``, ``authors``, ``year``,
            ``abstract`` and ``url``.  The list is empty when the query is
            blank, ``limit`` is not positive, the request fails, or the
            response status is not 200.
        """
        # Nothing to search for: skip the network round-trip entirely.
        if not query or not query.strip() or limit <= 0:
            return []

        try:
            response = requests.get(
                f"{self.api_base}/paper/search",
                params={
                    "query": query,
                    "limit": limit,
                    "fields": ",".join(self.PAPER_FIELDS),
                },
                # Without a timeout a stalled connection blocks forever.
                timeout=self.timeout,
            )
            if response.status_code != 200:
                return []
            payload = response.json()
            # "data" may be absent or null; treat both as "no hits".
            return [self._to_record(paper) for paper in payload.get("data") or []]
        except Exception as e:
            # Deliberate best-effort boundary: report and fall back to "no results".
            print(f"文献搜索出错: {str(e)}")
            return []

class TextGenerationTool:
    """Generate text with an OpenAI chat model accessed through LangChain.

    Generation is best-effort: any exception is reported on stdout and an
    empty string is returned instead.
    """

    def __init__(self, model_name: str = Config.OPENAI_MODEL):
        """Build the underlying chat model.

        Args:
            model_name: Name of the model to use.  Temperature, API key and
                API base URL are read from ``Config``.
        """
        self.chat_model = ChatOpenAI(
            model_name=model_name,
            temperature=Config.OPENAI_TEMPERATURE,
            openai_api_key=Config.OPENAI_API_KEY,
            openai_api_base=Config.OPENAI_API_BASE
        )

    def generate(self, prompt: str, max_tokens: int = 1000) -> str:
        """Generate text for *prompt*.

        Args:
            prompt: The user message sent to the model.  It is preceded by a
                fixed system message.
            max_tokens: Upper bound on the length of the completion.

        Returns:
            The text of the first generation, or ``""`` if anything fails.
        """
        try:
            messages = [
                SystemMessage(content="你是一个专业的学术论文写作助手。"),
                HumanMessage(content=prompt)
            ]
            # Forward the limit as a per-call keyword so it is actually
            # honoured without mutating the shared chat model instance.
            response = self.chat_model.generate([messages], max_tokens=max_tokens)
            return response.generations[0][0].text
        except Exception as e:
            # Deliberate best-effort boundary: report and return an empty result.
            print(f"文本生成出错: {str(e)}")
            return ""

class GrammarCheckTool:
    """Check text with the public LanguageTool HTTP API and apply its fixes.

    The tool is best-effort: if the request fails, the status is not 200, or
    the response cannot be processed, the input text is returned unchanged.
    """

    def __init__(self, language: str = "en-US", timeout: float = 15.0):
        """Create the tool.

        Args:
            language: Language code sent to LanguageTool.
            timeout: Seconds to wait for the HTTP request before giving up.
        """
        # Public LanguageTool API endpoint.
        self.api_url = "https://api.languagetool.org/v2/check"
        self.language = language
        self.timeout = timeout

    @staticmethod
    def _apply_matches(text: str, matches: List[Dict]) -> str:
        """Apply the first suggested replacement of each match to *text*.

        Matches without replacements are ignored.  Edits are applied from
        the end of the text towards the start so earlier offsets stay valid,
        and a match that overlaps an already applied edit is skipped.
        """
        # LanguageTool is a Java service and is understood to count offsets
        # and lengths in UTF-16 code units, not Python code points
        # (NOTE(review): confirm against the service). Splicing the UTF-16
        # encoding keeps positions right for characters outside the BMP; for
        # BMP-only text it is identical to plain string slicing.
        units = text.encode("utf-16-le", "surrogatepass")

        usable = [match for match in matches if match.get("replacements")]
        # Do not rely on the service's ordering; sort explicitly.
        usable.sort(key=lambda match: match["offset"], reverse=True)

        boundary = len(units)  # byte position where the last applied edit starts
        for match in usable:
            start = 2 * match["offset"]
            end = start + 2 * match["length"]
            if end > boundary:
                # Overlaps an edit already applied further right; applying
                # both would corrupt the text.
                continue
            replacement = match["replacements"][0]["value"]
            units = (
                units[:start]
                + replacement.encode("utf-16-le", "surrogatepass")
                + units[end:]
            )
            boundary = start
        return units.decode("utf-16-le", "surrogatepass")

    def check_and_correct(self, text: str) -> str:
        """Check *text* for grammar errors and return the corrected text.

        Args:
            text: The text to check.

        Returns:
            The text with the first suggestion of every match applied, or
            the original text if it is blank or the check fails.
        """
        # Blank input cannot contain errors; skip the network round-trip.
        if not text or not text.strip():
            return text

        try:
            response = requests.post(
                self.api_url,
                data={
                    "text": text,
                    "language": self.language
                },
                # Without a timeout a stalled connection blocks forever.
                timeout=self.timeout,
            )
            if response.status_code != 200:
                return text
            result = response.json()
            return self._apply_matches(text, result.get("matches", []))
        except Exception as e:
            # Deliberate best-effort boundary: report and keep the input as is.
            print(f"语法检查出错: {str(e)}")
            return text

class FormattingTool:
    """Assemble the parts of a paper into one Markdown-style string."""

    def format_paper(self, content: Dict, style: str = "APA") -> str:
        """Lay out a paper from its parts.

        Args:
            content: Mapping that may contain ``title``, ``authors`` (joined
                with ``", "``), ``abstract``, ``introduction``, ``methods``,
                ``results``, ``discussion``, ``conclusion`` and
                ``references`` (one bullet per entry).  Absent keys are
                skipped.
            style: Accepted for interface purposes; this implementation does
                not consult it.

        Returns:
            The formatted text.  If anything goes wrong while formatting,
            the error is printed and ``str(content)`` is returned instead.
        """
        try:
            pieces = []

            # Title line.
            if "title" in content:
                pieces.append(f"# {content['title']}\n\n")

            # Author line.
            if "authors" in content:
                byline = ", ".join(content["authors"])
                pieces.append(byline + "\n\n")

            # Abstract followed by the body sections, each under a heading
            # derived from its key.
            ordered_keys = (
                "abstract",
                "introduction",
                "methods",
                "results",
                "discussion",
                "conclusion",
            )
            for key in ordered_keys:
                if key not in content:
                    continue
                pieces.append(f"## {key.title()}\n\n")
                pieces.append(content[key] + "\n\n")

            # Reference list, one bullet per entry.
            if "references" in content:
                pieces.append("## References\n\n")
                pieces.extend(f"- {entry}\n" for entry in content["references"])

            return "".join(pieces)
        except Exception as err:
            print(f"格式化出错: {str(err)}")
            return str(content)